package com.neural;

import java.util.*;

/**
 * Java Neural Network Example
 * Handwriting Recognition
 * by Jeff Heaton (http://www.jeffheaton.com) 1-2002
 * -------------------------------------------------
 * Abstract base class for Neural Networks.
 *
 * @author Jeff Heaton (http://www.jeffheaton.com)
 * @version 1.0
 */

abstract public class Network {

    /**
     * The value to consider a neuron on
     */
    public final static double NEURON_ON = 0.9;

    /**
     * The value to consider a neuron off
     */
    public final static double NEURON_OFF = 0.1;

    /**
     * Output neuron activations
     */
    protected double output[];

    /**
     * Mean square error of the network
     */
    protected double totalError;

    /**
     * Number of input neurons
     */
    protected int inputNeuronCount;

    /**
     * Number of output neurons
     */
    protected int outputNeuronCount;

    /**
     * Random number generator
     */
    protected Random random = new Random(System.currentTimeMillis());

    /**
     * Called to learn from training sets.
     *
     * @exception java.lang.RuntimeException
     */
    abstract public void learn() throws RuntimeException;

    /**
     * Called to present an input pattern.
     *
     * @param input The input pattern
     */
    abstract void trial(double[] input);

    /**
     * Called to get the output from a trial.
     */
    double[] getOutput() {
        return output;
    }

    /**
     * Called to calculate the trial errors.
     *
     * @param train The training set.
     * @return The trial error.
     * @exception java.lang.RuntimeException
     */
    double calculateTrialError(TrainingSet train) throws RuntimeException {
        int i, size, tset, tclass;
        double diff;

        totalError = 0.0; // reset total error to zero

        // loop through all samples
        for (int t = 0; t < train.getTrainingSetCount(); t++) {
            // trial
            trial(train.getOutputSet(t));

            tclass = (int) (train.getClassify(train.getInputCount() - 1));
            for (i = 0; i < train.getOutputCount(); i++) {
                if (tclass == i) {
                    diff = NEURON_ON - output[i];
                }
                else {
                    diff = NEURON_OFF - output[i];
                }
                totalError += diff * diff;
            }

            for (i = 0; i < train.getOutputCount(); i++) {
                diff = train.getOutput(t, i) - output[i];
                totalError += diff * diff;

            }
        }

        totalError /= (double) train.getTrainingSetCount(); ;
        return totalError;
    }

    /**
     * Calculate the length of a vector.
     *
     * @param v vector
     * @return Vector length.
     */
    static double vectorLength(double v[]) {
        double rtn = 0.0;
        for (int i = 0; i < v.length; i++) {
            rtn += v[i] * v[i];
        }
        return rtn;
    }

    /**
     * Called to calculate a dot product.
     *
     * @param vec1 one vector
     * @param vec2 another vector
     * @return The dot product.
     */

    double dotProduct(double vec1[], double vec2[]) {
        int k, m, v;
        double rtn;

        rtn = 0.0;
        k = vec1.length / 4;
        m = vec1.length % 4;

        v = 0;
        while ( (k--) > 0) {
            rtn += vec1[v] * vec2[v];
            rtn += vec1[v + 1] * vec2[v + 1];
            rtn += vec1[v + 2] * vec2[v + 2];
            rtn += vec1[v + 3] * vec2[v + 3];
            v += 4;
        }

        while ( (m--) > 0) {
            rtn += vec1[v] * vec2[v];
            v++;
        }

        return rtn;
    }

    /**
     * Called to randomize weights.
     *
     * @param weight A weight matrix.
     */
    void randomizeWeights(double weight[][]) {
        double r;

        int temp = (int) (3.464101615 / (2. * Math.random())); // SQRT(12)=3.464...

        for (int y = 0; y < weight.length; y++) {
            for (int x = 0; x < weight[0].length; x++) {
                r = (double) random.nextInt(Integer.MAX_VALUE) +
                    (double) random.nextInt(Integer.MAX_VALUE) -
                    (double) random.nextInt(Integer.MAX_VALUE) -
                    (double) random.nextInt(Integer.MAX_VALUE);
                weight[y][x] = temp * r;
            }
        }
    }

}
